#include <torch/all.h>

#include "es_sm100_mxfp8_blockscaled_group_quant.cuh"

// Grouped block-scaled MXFP8 quantization entry point for SM100 (Blackwell).
//
// Validates the host-side tensor metadata, then dispatches to the templated
// device launcher for the input's element type (fp16 or bf16).
//
// Preconditions (checked):
//   - input: 2D, row-major (innermost stride 1), k = input.size(1) divisible
//     by 128, dtype kFloat16 or kBFloat16.
//   - problem_sizes: 2D; its first dimension defines the number of groups.
//   - expert_offsets / blockscale_offsets: 1D, length == number of groups.
//
// Outputs (written by the launcher):
//   - quant_output: quantized FP8 values.
//   - scale_factor: per-block scale factors.
//     NOTE(review): output shapes/dtypes are validated inside the launcher,
//     if at all — not visible from this wrapper; confirm in the .cuh.
//
// Requires a CUTLASS build with SM100 MMA support; otherwise raises.
void es_sm100_mxfp8_blockscaled_grouped_quant(
    const torch::Tensor& input,
    const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets,
    torch::Tensor& quant_output,
    torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
  TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
  TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");

  auto groups = problem_sizes.size(0);
  TORCH_CHECK(
      expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
      "expert_offsets must be 1D and have size equal to the number of groups");
  TORCH_CHECK(
      blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
      "blockscale_offsets must be 1D and have size equal to the number of groups");

  // No stream is passed through this interface; the launcher presumably picks
  // up the current CUDA stream itself — TODO(review): confirm in the .cuh.
  // (A previously fetched-but-unused `stream` local was removed here.)
  if (input.dtype() == torch::kBFloat16) {
    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__nv_bfloat16>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
  } else if (input.dtype() == torch::kFloat16) {
    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__half>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
  } else {
    // Report the offending dtype to make the failure actionable.
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16, got ", input.dtype());
  }
#else
  // Fixed copy-paste in the message: this is the grouped *quant* entry point,
  // not grouped_mm.
  TORCH_CHECK(false, "No implemented es_sm100_mxfp8_blockscaled_grouped_quant for current device");
#endif
}
